import ekf_testv6 as ekf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as ss
import summary_of_sims as sum_sims
from statsmodels.graphics import tsaplots
import matplotlib.transforms as mtransforms

# Number of rows in `values`: one per smoothing method plus the raw observations.
matsh=10
# Ordinates at which the kernel density estimates of yearly changes are evaluated.
eval_points=np.arange(-.25,.15,.001)
# Row order shared by `values`, `shlabels` and `colorlabels`:
#   0 running 30-year mean, 1 OCN, 2 Butterworth, 3 change point, 4 EBM-KF-uf,
#   5 EBM-KF-ta, 6 RTS, 7 LENS2 ensemble average, 8 blind model, 9 HadCRUT5 observations.
# The rows without uncertainty estimates come last, so the uncertainty arrays
# (which have matsh-2 rows) can simply be shorter.

# Long legend labels (rows 0-8).  Backslashes are doubled so that the mathtext
# commands do not rely on Python's deprecated "invalid escape sequence" fallback.
labels = [
    "Running Average $\\overline{_{30}Y_t}$",
    "OCN Chunks",
    "Butterworth Smoothed",
    "Change Point Lines",
    "EBM-KF-uf $\\hat {T }_t$",
    "EBM-KF-ta",
    "RTS $ \\hat \\hat {T }_t$",
    "LENS2 $\\overline{(Y_t)_j}$",
    "Blind Model $\\~{T}_{t}$",
]
# Short two-line labels (rows 0-9); the "\n" here is a real line break.
shlabels = [
    "RunAvg \n $\\overline{_{30}Y_t}$",
    "OCN \n Chunks",
    "ButW \n Smooth",
    "$\\Delta$Pt \n Lines",
    "EBM-KF-uf \n $\\hat {T }_t$",
    "EBM-KF-ta",
    "RTS \n $ \\hat \\hat {T }_t$",
    "LENS2 \n $\\overline{(Y_t)_j}$",
    "Blind \n $\\~{T}_{t}$",
    "HadCRUT5 \n ${Y}_{t}$",
]
# One plotting colour per row of `values` (ten entries, same order as `shlabels`).
colorlabels = [
    "yellow",
    "limegreen",
    "springgreen",
    ekf.pcolor,
    ekf.colorekf,
    "darkolivegreen",
    ekf.colorrts,
    sum_sims.cesmcolor,
    "darkgoldenrod",
    ekf.colorgrey,
]
# Unit weights, later divided by the sample count to normalise histograms.
oneswt = np.ones(ekf.n_iters)

# Result tables, one row per method and one column per time step.  Everything
# starts as NaN so that entries a method never fills stay NaN.
values = np.full((matsh, ekf.n_iters), np.nan)
stdevs = np.full((matsh - 2, ekf.n_iters), np.nan)
moderrs = np.full((matsh - 2, ekf.n_iters), np.nan)
temps = ekf.temps

# Sanity check: prints True when the temperature series has exactly n_iters entries.
print(len(ekf.temps) == ekf.n_iters)

# Row 0: running mean, copied from the precomputed arrays in the ekf module.
values[0, :] = ekf.moving_aves
stdevs[0, :] = ekf.std_aves
# presumably ekf.N is the averaging window length, making this a standard error - TODO confirm
moderrs[0, :] = ekf.std_aves / np.sqrt(ekf.N)

# Rows 1-3 are read from CSV files written by other scripts.  The path is handed
# straight to genfromtxt so that numpy opens *and closes* the file itself
# (previously an open() handle was passed in and never closed).
# In each file row 1 supplies the values, row 3 feeds `stdevs` and row 2 feeds `moderrs`.

#OCN
dataOCN = np.genfromtxt("Mann_Smoothing_08/OCNtemperatures.csv", dtype=float, delimiter=',')
# The OCN series may be shorter than n_iters; the remaining columns stay NaN.
lenOCN=len(dataOCN[1,:])
values[1,:lenOCN]=dataOCN[1,:]
# NOTE(review): row 3 is halved - presumably it stores a 2-sigma width; confirm against the writer.
stdevs[1,:lenOCN]=dataOCN[3,:]/2
moderrs[1,:lenOCN]=dataOCN[2,:]

#butterworth
dataBU = np.genfromtxt("Mann_Smoothing_08/BUStemperatures.csv", dtype=float, delimiter=',')
values[2,:]=dataBU[1,:]
fN=ekf.fN
cN=ekf.cN
# Uncertainties are averaged over the window [i-fN, i+cN) and doubled; only
# indices whose whole window lies inside the record are filled, the rest stay NaN.
for i in range(fN, len(temps)-cN+1):
    lasta=i+cN
    firsta=i-fN
    stdevs[2,i] = 2*np.mean(np.abs(dataBU[3,firsta:lasta]))
    moderrs[2,i]= 2*np.mean(dataBU[2,firsta:lasta])

#change point
dataCP = np.genfromtxt("Bayes_Sequential_Change_Point/BSCtemperatures.csv", dtype=float, delimiter=',')
values[3,:]=dataCP[1,:]
stdevs[3,:]=dataCP[3,:]/2
moderrs[3,:]=dataCP[2,:]

# Rows 4 and 6-9: filter output, simulations, blind model and raw observations.
# (Row 5, EBM-KF-ta, is filled further down once its filter run has been made.)

#ekf
values[4,:]=np.copy(ekf.xh1s)
stdevs[4,:]=np.copy(ekf.stdS)
moderrs[4,:]=np.copy(ekf.stdP)
#stdevs[4,0:2]=np.nan

#rts
values[6,:]=np.copy(ekf.xhh1s)
stdevs[6,:]=np.copy(ekf.stdSh)
moderrs[6,:]=np.copy(ekf.stdPh)
# The final RTS innovation spread is blanked out.
stdevs[6,-1]=np.nan

#simulations
#lenss=sum_sims.smyrs
# Hard-coded number of simulation years copied into row 7.
lenss1=174
values[7,:lenss1]=sum_sims.twTRmean[:lenss1]
stdevs[7,:lenss1]=sum_sims.twTRstd[:lenss1]
# The CMIP6 summary covers fewer years; row 20 of that file is used as the model error.
lenss=165
# Pass the path so numpy closes the file itself (an open() handle used to be leaked here).
CMIP6=np.genfromtxt("summaryCMIP6.csv", dtype=float, delimiter=',')
moderrs[7,:lenss]=CMIP6[20,:]
#blind model
values[8,:]=ekf.xblind[:,0]

# Row 9 holds the observations themselves.
values[9,:]=temps

# EBM-KF-ta (row 5): re-run the filter with a time-averaged optical depth.
# The averaging is done on the transformed quantity 1/(opt_depth + 9.7279)
# and transformed back afterwards.
wt_opt_depths = 1/(ekf.opt_depth+9.7279)
N = 30
# Start every entry at ekf.involcavg; entries the loop below never touches keep that value.
nwt_opt_depths=np.empty(len(ekf.opt_depth)); nwt_opt_depths[:]=ekf.involcavg
cN=int(np.ceil(N/2))
fN=int(np.floor(N/2))
# The loop covers indices fN .. len-2, so the first fN entries and the last one are left at involcavg.
for i in range((fN),(len(nwt_opt_depths)-1)):
    # `lasta` is not used by this loop (only `firsta` is).
    lasta=i+cN;firsta=i-fN
    # fN+1 actual samples (indices firsta..i) plus cN-1 copies of involcavg make N terms in total.
    nwt_opt_depths[i] = (np.sum(wt_opt_depths[(firsta):i+1])+ ekf.involcavg*(cN-1))/N
    #computing half-average - future is assumed to be the average 
nopt_depths=(1/nwt_opt_depths-9.7279)
# NOTE: this overwrites the optical depth inside the ekf module for every later ekf_run call.
ekf.opt_depth=nopt_depths
(trailfilter_xhat,P2s,S2s,trailfilter_xhatm)=ekf.ekf_run(ekf.observ,ekf.n_iters,retPs=2)
xh2s=trailfilter_xhat[:,0]
stdP2s=np.sqrt(np.abs(P2s))[:,0,0]
values[5,:]=xh2s
# NOTE(review): row 4 stores ekf.stdS in `stdevs` and ekf.stdP in `moderrs`, whereas here the
# P-named result goes to `stdevs` and the S-named one to `moderrs`.  Judging by the names the
# two look swapped - confirm against the return order of ekf.ekf_run(..., retPs=2).
stdevs[5,:]=stdP2s
moderrs[5,:]=np.sqrt(np.abs(S2s))[:,0,0]

# Figure `figh`: temperature means (axl), year-to-year changes (axp) and the
# density of those changes (ax_histy).
figh = plt.figure(figsize=(12,4))
#figh,(axl,axp)=plt.subplots(1,2,figsize=(12,4))
#figh.subplots_adjust(wspace=0.4)
# Four grid columns; column 1 (relative width 0.7) never receives an axes, so it acts as a spacer.
gs = figh.add_gridspec(1, 4,  width_ratios=(4,.7,4,1),
                      left=0.1, right=0.9, bottom=0.1, top=0.9,
                      wspace=0.2, hspace=0.05)


axl=figh.add_subplot(gs[0, 0])
axp=figh.add_subplot(gs[0, 2])
axp.set_ylim([-0.225,0.1])
# The density panel shares its y axis with the scatter panel.
ax_histy = figh.add_subplot(gs[0, 3], sharey=axp)
# Histogram bin edges; in this file they are referenced only by the commented-out hist calls.
binwidth = 0.005
xymax = 0.21
lim = (int(xymax/binwidth) + 1) * binwidth
bins = np.arange(-lim, lim + binwidth, binwidth)

# Observation minus EBM-KF-uf state (row 4), and the year-to-year change of that state.
diffptsEKF=-values[4,:-1]+temps[0:-1]
derivptsEKF=values[4,1:]-values[4,:-1]

# Same two quantities for the running mean (row 0), trimmed to indices 15..-16
# (presumably where the 30-year window is complete - TODO confirm).
diffpts30m=-values[0,15:-16]+temps[15:-16]
derivpts30m= values[0,16:-15]-values[0,15:-16]
axp.plot(ekf.dates[1:],derivptsEKF,'.',color=colorlabels[4],label="EBM-KF-uf",zorder=2, markersize=6)
#ax_histy.hist(derivptsEKF[1:], bins=bins, histtype=u'step',orientation='horizontal',color=colorlabels[4],zorder=2)
# The density estimate skips the very first change (derivptsEKF[0]).
ax_histy.plot(ss.gaussian_kde(derivptsEKF[1:]).pdf(eval_points),eval_points,color=colorlabels[4])
# Running-mean points: yellow face with an edge in colorlabels[1].
axp.plot(ekf.dates[16:-15],derivpts30m,'.',color='yellow', markeredgecolor=colorlabels[1],label="30-year running mean",zorder=3,alpha=1, markersize=6,markeredgewidth=1)
#ax_histy.hist(derivpts30m, bins=bins,orientation='horizontal',color=colorlabels[0],zorder=1,alpha=0.5)
deriv30hist = ss.gaussian_kde(derivpts30m).pdf(eval_points)
# The same density curve is drawn twice: solid in colorlabels[1], then a thin dashed yellow line on top.
ax_histy.plot(deriv30hist,eval_points,color=colorlabels[1])
ax_histy.plot(deriv30hist,eval_points,'--',color='yellow',linewidth=0.8)
##axp.set_xlabel("Innovation from Mean/State to Measurement")
axp.set_ylabel("Change in Temperature State (°C/yr)")
axp.set_title("Comparison of Annual Temperature Changes")
##
##legend1=plt.legend(loc="lower right")
##
# Year groups to annotate, one list per eruption in `noteVolcs` (same order).
noteYears = [
    list(range(1991, 1999)),
    list(range(1963, 1966)),
    list(range(1902, 1905)),
    list(range(1883, 1886)),
    list(range(1856, 1859)),
]
# Two-line annotation text for each eruption.
# Not included: "Komaga-take, Japan", "Shiveluch, Russia"
noteVolcs = [
    "Mt. Pinatubo, \nPhilippines",
    "Mt. Agung, \nIndonesia",
    "Santa Maria, \nGuatemala",
    "Krakatoa, \nIndonesia",
    "Mt. Awu, \nIndonesia",
]
# Marker shapes; in this file only the commented-out scatter annotation code refers to them.
symbols = ['^', '>', 's', 'P', '*', 'X']
##volplots=[]
##for i in range(len(noteVolcs)):
##    ds=np.array(noteYears[i])-1850-1
##    volplot,=axp.plot(diffptsEKF[ds],derivptsEKF[ds],marker=symbols[i],linestyle='None',color='0.7',zorder=0)
##    volplots.append(volplot)
##    #axp.plot(diffpts30m[ds-15],derivpts30m[ds-15],marker=symbols[i],linestyle='None',color='k',zorder=0)
##    for d in ds:
##        if(d+1850+1==1902):
##            xshf=-5
##        elif(d+1850+1==1963):
##            xshf=-10
##        elif(d+1850+1==1998):
##            xshf=3
##        else:
##            xshf=0
##        plt.annotate((d+1850+1), (diffptsEKF[d],derivptsEKF[d]),fontsize=8,font='Arial Narrow',
##                 textcoords="offset points", xytext=(xshf,-3+6*np.sign(derivptsEKF[d])), ha='center')
##       # plt.annotate((d), (diffpts30m[d-15],derivpts30m[d-14]),
##       #          textcoords="offset points", xytext=(0,0.01), ha='center')
##    plt.annotate("1850 - initalized", (diffptsEKF[0],derivptsEKF[0]),fontsize=8,font='Arial Narrow',
##                 textcoords="offset points", xytext=(0,-3-6), ha='center')
##legend2 = axp.legend(volplots,noteVolcs,loc="center right",fontsize=8)
##axp.add_artist(legend1)
##axp.set_ylim(-0.25,0.11)



# Console diagnostics.
print("MAX INDEX")
# Position (within the trimmed running-mean range) of the largest observation-minus-mean gap.
print(np.argmax(diffpts30m))
# The twenty most negative year-to-year changes of the EBM-KF-uf state, as (date, change) pairs.
sortedderivs = sorted(zip(ekf.dates[1:],derivptsEKF), key=lambda x:x[1])
print(sortedderivs[0:20])

# R^2 comparisons between rows of `values`:
#   0 running mean, 4 EBM-KF-uf, 7 LENS2 ensemble average, 8 blind model.
# The LENS2 and blind comparisons previously used rows 6 and 7, which became RTS and
# LENS2 when EBM-KF-ta was inserted as row 5; the indices now match the printed labels.
# Comparisons against the running mean use 15:-16, the same trimmed range as diffpts30m.
r2 = ekf.r2_score(values[0,15:-16], values[4,15:-16])
print('r2 score for running mean is to EBM-KF', r2)
r2 = ekf.r2_score(values[0,15:-16], values[8,15:-16])
print('r2 score for running mean is to blind', r2)
r2 = ekf.r2_score(values[0,15:-16], values[7,15:-16])
print('r2 score for running mean to LENS2 ens', r2)
r2 = ekf.r2_score(values[7,:], values[4,:])
print('r2 score for LENS2 ens to EBM-KF', r2)
r2 = ekf.r2_score(values[8,:], values[4,:])
print('??? r2 score for blind to EBM-KF', r2)

# Panel letters for the three axes of `figh`.
axlabels=['a)','b)','c)']
axs=(axl,axp, ax_histy)
for i, (ax, label) in enumerate(zip(axs, axlabels)):
    # Offsets are physical distances in units of 1/72 inch: the first two
    # letters sit to the left of the axes corner, the third is pushed to the right.
    x_shift = 20/72 if i==2 else -40/72
    trans = mtransforms.ScaledTranslation(x_shift, 1/72, figh.dpi_scale_trans)
    ax.text(
        0.0, 1.0, label,
        transform=ax.transAxes + trans,
        fontsize='large',
        va='bottom',
    )



#graparr=[[4,0,1,2],[4,3,6],[4,5,7]]
# Second figure: a matsh x 4 grid of small panels.  Rows 0..matsh-2 come from
# `gs_base`; the bottom row is taken from `gs_bot`, a separate GridSpec with a
# lower top edge and default vertical spacing.
gs_bot = plt.GridSpec(matsh, 4, wspace=0.3, top=0.8)
gs_base = plt.GridSpec(matsh, 4, wspace=0.3, hspace=0.1)
fig = plt.figure(figsize=(12,11))

# Build the axes column by column (bottom axes first, then the rows above it) ...
ax0s = []
for j in range(4):
    botax = fig.add_subplot(gs_bot[matsh-1, j])
    other_axes = [fig.add_subplot(gs_base[i, j]) for i in range(matsh-1)]
    other_axes.append(botax)
    ax0s.append(other_axes)
# ... then transpose so that axs[row][column] addresses a panel.
axs = [[col_axes[row] for col_axes in ax0s] for row in range(matsh)]

fig.suptitle("Histogram Comparisons of Smoothing Methods on GMST\n (Bar Height Represents Fraction of Timepoints)")
# Fill the grid, one row per row `d` of `values`:
#   column 0 - histogram of observation minus the method's value,
#   column 1 - histogram of the method's year-to-year change,
#   column 2 - autocorrelation of the method's difference from the blind model (row 8),
#   column 3 - histogram of run lengths spent above/below the observations (row 9).
# Panels that make no sense for a row (e.g. observations minus themselves) are blanked.
for d in range(0,matsh):

    # ---- column 0: observation minus value ----
    ax = axs[d][0]
    if d==0:
        ax.set_title("Real HadCRUT5\n Measurements\n Minus Value")
    if d==9:
        ax.set_xticklabels([])
    # Only row 8, the last row drawn in this column, carries the x label and tick labels.
    if d==8:
        ax.set_xlabel("Temperature Difference (K)")
    elif d<=7:
        ax.set_xticklabels([])


    if d<=8:
        diff=-values[d,:]+temps
        count=np.count_nonzero(~np.isnan(diff))
        # Every sample is weighted 1/count, where count is the number of non-NaN differences.
        ax.hist(diff,weights=oneswt/count,bins=np.linspace(-0.5,0.5,50),alpha=1, color=colorlabels[d])
    #ax.legend(loc='best',prop={"size":8})
        ax.set_ylim(0.0,0.14)
        ax.set_xlim(-0.4,0.4)
        ax.tick_params( direction = 'out',bottom=True, top=False, left=True, right=True )

    else:
        # Row 9 (the observations): hide the frame and ticks, keeping only the row label.
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.tick_params( direction = 'out',bottom=False, top=False, left=False, right=False)
        ax.set_zorder(-1)
        ax.set_yticklabels([])

#    if d<8:
#        ax.set_ylabel(shlabels[d]) 
    # The row label is attached to the leftmost panel.
    ax.set_ylabel(shlabels[d])

    # ---- column 1: year-to-year change ----
    ax =axs[d][1] 
    if d==0:
        ax.set_title("Yearly Value Change")
    # Tick labels are kept on rows 8 and 9 only; row 9 also gets the axis label.
    if d==9:
        ax.set_xlabel("Temperature Change (K/year)")
    elif d!=8:
        ax.set_xticklabels([])
    #else:
       # ax.set_xticklabels([])

    diff=values[d,1:]-values[d,:-1]
    count=np.count_nonzero(~np.isnan(diff))
    ax.hist(diff,weights=oneswt[1:]/count,bins=np.linspace(-0.15,0.15,150),alpha=1, color=colorlabels[d])
    #ax.legend(loc='best',prop={"size":8})ck
    if d==1:
        # The y range is capped for OCN and an asterisk is drawn inside the panel.
        ax.set_ylim(0.0,0.11) #ocn chunks too high, almost all are 0
        ax.text(0.002,0.08,"*")
    ax.set_xlim(-0.065,0.065)
    ax.set_xticks(np.arange(-0.06, 0.07, 0.03))
    if d==9:
        # The observations get a wider x range and coarser bins.
        # NOTE(review): the fine-binned histogram drawn a few lines above stays on this
        # axes, so two histograms overlap here - confirm that this is intended.
        #ax.set_ylim(0.0,0.07)
        ax.set_xlim(-0.25,0.25)
        ax.set_xticks(np.arange(-0.2, 0.25, 0.1))
        ax.hist(diff,weights=oneswt[1:]/count,bins=np.linspace(-0.35,0.35,50),alpha=1, color=colorlabels[d])


    ax.tick_params( direction = 'out',bottom=True, top=False, left=True, right=True )
#    if d==8:
#        ax.set_ylabel(shlabels[d]) 

    # ---- column 2: autocorrelation of (value - blind model) ----
    ax2 =axs[d][2]
    if d!=8:
        tsaplots.plot_acf(values[d,:]-values[8,:], lags=30, ax=ax2, missing='drop', title='')
        ax2.set_ylim(-0.9,1.1)
    else:
        # Row 8 is the blind model itself, so the panel is blanked.
        ax2.spines['right'].set_visible(False)
        ax2.spines['top'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        ax2.spines['left'].set_visible(False)
        ax2.tick_params( direction = 'out',bottom=False, top=False, left=False, right=False)
        ax2.set_zorder(-1)
        ax2.set_yticklabels([])
    if d==0:
        ax2.set_title("Autocorrelation of\n Difference from Blind Model")

    # Rows 7 and 9 are the drawn panels directly above a blank one / at the bottom.
    if  (d==9 or d==7):
        ax2.set_xlabel("Lag Time (years)")
    else:
        ax2.set_xticklabels([])

    # ---- column 3: lengths of runs above / below the observations ----
    ax3=axs[d][3]
    ax3.set_ylim(0,35)
    # x range of the run-length histogram: -ranflickerb .. +ranflicker, ticks every stepflicker.
    ranflickerb=32;
    ranflicker=8;
    stepflicker=8;
    ax3.set_xticks(np.arange(-ranflickerb, ranflicker+stepflicker, step=stepflicker)) 

    if d==0:
        ax3.set_title("Crossing Times\n with Real HadCRUT5")
    if d==8:
        ax3.set_xlabel("String of Years\n Above(+) or Below(-)\n all years given equal weight")
    else:
        ax3.set_xticklabels([])

    if d<9:
        nonnans=~np.isnan(values[d,:])
        # String of '1' (value >= observation) and '0' (value < observation), one character per year.
        gted="".join(((np.greater_equal(values[d,nonnans],values[9,nonnans])).astype(int)).astype(str))
        # Splitting on '1' leaves the runs of '0' (negated lengths); splitting on '0' leaves the runs of '1'.
        negs=np.fromiter(map(len,gted.split('1')),dtype=int)*(-1)
        posis=np.fromiter(map(len,gted.split('0')),dtype=int)
        histo=np.concatenate((negs,posis))
        # Each run is weighted by its own length, so the empty strings produced by split() add nothing.
        ax3.hist(histo,weights=np.abs(histo),alpha=1, bins=(ranflicker+ranflickerb)*2+1,
                 range=(-ranflickerb-0.25,ranflicker+0.25),color=colorlabels[d])
    else:
        # Row 9 would compare the observations with themselves, so the panel is blanked.
        ax3.spines['right'].set_visible(False)
        ax3.spines['top'].set_visible(False)
        ax3.spines['bottom'].set_visible(False)
        ax3.spines['left'].set_visible(False)
        ax3.tick_params( direction = 'out',bottom=False, top=False, left=False, right=False)
        ax3.set_zorder(-1)
        ax3.set_yticklabels([])
	


axl.set_title('Comparison of Temperature Means')

# EBM-KF-ta state, drawn relative to ekf.pindavg.
axl.plot(
    ekf.dates, xh2s-ekf.pindavg, '-',
    label='EBM-KF-ta', color='darkolivegreen', zorder=5,
)

# Replace the entries fN .. len-cN of the averaged (transformed) optical depth
# with a plain mean over the window [i-fN, i+cN); entries outside that range
# keep whatever value they already had.
for i in range(fN, len(nwt_opt_depths)-cN+1):
    lasta = i+cN
    firsta = i-fN
    nwt_opt_depths[i] = np.mean(wt_opt_depths[firsta:lasta])
nopt_depths = 1/nwt_opt_depths-9.7279
# Overwrites the optical depth inside the ekf module, then runs the filter again.
ekf.opt_depth = nopt_depths
# NOTE(review): xh3s is not referenced anywhere else in the visible part of this file.
smoothed_xhat, P3s = ekf.ekf_run(ekf.observ, ekf.n_iters, retPs=True)
xh3s = smoothed_xhat[:,0]

# Year-to-year change of the EBM-KF-ta state, starting from the second change.
ta_changes = xh2s[2:]-xh2s[1:-1]
axp.plot(ekf.dates[2:], ta_changes, '.', color='darkolivegreen', label="EBM-KF-ta")
ta_density = ss.gaussian_kde(ta_changes).pdf(eval_points)
ax_histy.plot(ta_density, eval_points, color="darkolivegreen", zorder=5)

# Year-to-year change of the simulation mean over the first lenss1 years.
sim_changes = sum_sims.twTRmean[1:lenss1]-sum_sims.twTRmean[:lenss1-1]
axp.plot(
    ekf.dates[1:], sim_changes, '.',
    color=colorlabels[7], label=shlabels[7], zorder=0,
)
sim_density = ss.gaussian_kde(sim_changes).pdf(eval_points)
ax_histy.plot(sim_density, eval_points, color=colorlabels[7], zorder=0)

axp.set_xlabel("Year")
ax_histy.set_xlabel("# of Years")

r2 = ekf.r2_score(values[0,15:-16], xh2s[15:-16])
print('r2 score for running mean is to EBM-KF-ta', r2)

# Remaining curves of the means panel, all drawn relative to ekf.pindavg.
toplot = [7, 4]
for i in toplot:
    axl.plot(
        ekf.dates, values[i,:]-ekf.pindavg, '-',
        label=labels[i], color=colorlabels[i],
    )
# The running mean is drawn twice: a solid line with a thin dashed yellow line on top.
axl.plot(ekf.dates, values[0,:]-ekf.pindavg, '-', label=labels[0], color=colorlabels[1])
axl.plot(ekf.dates, values[0,:]-ekf.pindavg, '--', label=labels[0], color="yellow", linewidth=0.8)

# Proxy artists so that the legend shows both strokes of the running mean in one entry.
dotted_line1 = plt.Line2D([], [], linestyle="--", color="yellow", linewidth=0.8)
dotted_line2 = plt.Line2D([], [], linestyle="-", color=colorlabels[1])

# From here on `labels` holds the legend labels of axl, not the original label list.
handles0, labels = axl.get_legend_handles_labels()
handles0[3] = (dotted_line2, dotted_line1)
order = [0, 3, 2, 1]
legend_handles = [handles0[k] for k in order]
legend_texts = [labels[k] for k in order]
axl.legend(legend_handles, legend_texts, loc='upper left')
ekf.plot_boilerplate(axl)

axp.set_xticks(np.arange(1850, 2025+1, 25))
axp.set_xlim(1850, 2025)

# Dotted vertical line and text at the first year of every annotated eruption, on both panels.
shiftup = [0, 0, 0, 0.03, 0]         # per-eruption text offset added to -0.2 on axp
shift0 = [-0.2, -0.2, 0.4, 0.2, 0.4]  # per-eruption text height on axl
for vi, v in enumerate(noteYears):
    first_year = v[0]
    axl.plot([first_year, first_year], [-0.5, 1.7], ':', color='0.7', zorder=0)
    axl.annotate(noteVolcs[vi], (first_year, shift0[vi]), fontsize=8)
    axp.plot([first_year, first_year], [-0.5, 1.7], ':', color='0.7', zorder=0)
    axp.annotate(noteVolcs[vi], (first_year, -0.2+shiftup[vi]), fontsize=8)

# Left axis: difference from ekf.pindavg; right twin axis: the same range with 273.15 subtracted.
axl.set_yticks(np.arange(-0.4, 1.8, 0.2))
axl.set_ylabel('$\\Delta$ Temperature (°C)\nAbs(K)=$\\Delta(T)$ + 286.7K')
axl.set_ylim(286.3-ekf.pindavg, 288.1-ekf.pindavg)
axb = axl.twinx()
axb.set_yticks(np.arange(12.4, 18.4, 0.2))
axb.set_ylim(286.3-273.15, 288.1-273.15)
axb.set_ylabel('Temperature (°C)')

# Per-row (rows 0-6) column indices at which the start ('^') and end ('s') markers are drawn.
startpts=[15,0,  15, 0,2,  2,  0]
endpts=[-15,-2, -15,-1,-1,-2,-13]
# Draw order of the rows in the uncertainty figure.
order=[2,6,0,1,3,5,4]

# Save the histogram grid through its own handle rather than through pyplot's
# "current figure", which any intervening pyplot call could have changed.
fig.savefig("compare_hists1.pdf", format="pdf")
fig.savefig("compare_hists1.png", dpi=400,format="png")

# Third figure: `moderrs` against `stdevs` for rows 0-6.
fig_unc = plt.figure()
for rank, d in enumerate(order):
    # Line breaks are stripped from the short label for the legend.
    label2=shlabels[d].replace('\n','')
    if rank<=4:
        plt.plot(stdevs[d,:],moderrs[d,:],'-o',color=colorlabels[d], label=label2, alpha=0.7, markersize=2)
    else:
        # The last two rows in `order` are drawn opaque, with a thicker line and larger markers.
        plt.plot(stdevs[d,:],moderrs[d,:],'-o',color=colorlabels[d], label=label2, alpha=1, markersize=3, linewidth=2)
    # Unlabelled start and end markers; they do not appear in the legend.
    plt.plot(stdevs[d,startpts[d]],moderrs[d,startpts[d]],'^',color=colorlabels[d], markersize=5 ,markeredgewidth=.5, markeredgecolor=(0,0,0, 1))
    plt.plot(stdevs[d,endpts[d]],moderrs[d,endpts[d]],'s',color=colorlabels[d], markersize=5 ,markeredgewidth=.5, markeredgecolor=(0,0,0, 1))

# Handles come back in draw order (`order`); order2 lists them in row order 0..6 again.
order2=[2,3,0,4,6,5,1]
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend([handles[idx] for idx in order2],[labels[idx] for idx in order2])
plt.title("Comparison of Modeled GMST Variabilities")
# Backslashes are doubled: "\s" is not a valid Python escape sequence.
plt.ylabel("State Uncertainty $\\sqrt{p_{t}^T}$ or Std Error in K")
plt.xlabel("Prediction / Innovation Uncertainty $\\sqrt{s_{t}^T}$ or Std Dev in K")
plt.ylim([0,0.15])
plt.xlim([0,0.23])
plt.xticks(np.arange(0, .24, 0.02))

fig_unc.savefig("compare_hists2.pdf", format="pdf")
fig_unc.savefig("compare_hists2.png", dpi=400,format="png")

figh.savefig("compare_means_new.pdf", format="pdf")
figh.savefig("compare_means_new.png", dpi=400,format="png")


# Fourth figure (shown only, not saved): the two arrays transformed back with 1/x - 9.7279.
plt.figure()
plt.plot(ekf.dates, 1/wt_opt_depths-9.7279, color=ekf.colorekf)
plt.plot(ekf.dates, 1/nwt_opt_depths-9.7279, color='darkolivegreen')

plt.show()





